# -*- coding:utf-8 -*-

import tensorflow as tf
from utils.calculate_utils import multi_add

"""
RNN 模块
"""


def cell(state, x, w_aa, w_ax, w_ay, b_aa, b_ay):
    """
    One step of a vanilla RNN with separate input/recurrent weights.

    :param state: previous hidden state a<t-1>
    :param x: current input x<t>
    :param w_aa: recurrent weight, a --> a
    :param w_ax: input weight, a --> x
    :param w_ay: output weight, a --> y
    :param b_aa: hidden-state bias
    :param b_ay: output bias
    :return: (next hidden state, output y<t>)
    """
    # a<t> = tanh(w_ax @ x + w_aa @ a<t-1> + b_aa)
    input_term = tf.matmul(w_ax, x)
    recurrent_term = tf.matmul(w_aa, state)
    state_next = tf.tanh(multi_add(input_term, recurrent_term, b_aa))
    # y<t> = w_ay @ a<t> + b_ay
    out = tf.matmul(w_ay, state_next) + b_ay
    return state_next, out


def rnn_cell_custom(x, n, a_0, y_0):
    """
    Unroll ``n`` steps of the hand-rolled RNN ``cell`` with shared weights.

    NOTE(review): the same ``x`` is fed at every step (no time slicing) —
    presumably intentional for this demo; confirm against callers.

    :param x: input tensor; assumes shape (input_dim, batch) — TODO confirm
    :param n: number of steps (cells) to unroll
    :param a_0: hidden state size
    :param y_0: output size
    :return: list of [state, y] pairs, one per step
    """
    x_shape = x.shape.as_list()
    # Initial hidden state and the parameters shared across all steps.
    state = tf.Variable(tf.truncated_normal((a_0, x_shape[1]), stddev=0.1))
    w_aa = tf.Variable(tf.truncated_normal((a_0, a_0), stddev=0.1))
    w_ax = tf.Variable(tf.truncated_normal((a_0, x_shape[0]), stddev=0.1))
    w_ay = tf.Variable(tf.truncated_normal((y_0, a_0), stddev=0.1))
    b_aa = tf.Variable(tf.constant(0.1, shape=[a_0, x_shape[1]]))
    b_ay = tf.Variable(tf.constant(0.1, shape=[y_0, x_shape[1]]))
    outputs = []  # collects [state, y] for every unrolled step
    for _ in range(n):
        state, y = cell(state, x, w_aa, w_ax, w_ay, b_aa, b_ay)
        outputs.append([state, y])
    return outputs


def cell2(x_a, w_a, w_ay, b_aa, b_ay):
    """
    One RNN step in the concatenated formulation: a single weight matrix
    acts on [x; a] instead of separate input/recurrent weights.

    :param x_a: concatenation of input x and hidden state a (axis 0)
    :param w_a: weight, [x, a] --> a
    :param w_ay: output weight, a --> y
    :param b_aa: hidden-state bias
    :param b_ay: output bias
    :return: (next hidden state, output)
    """
    pre_activation = tf.matmul(w_a, x_a) + b_aa
    state_next = tf.tanh(pre_activation)
    out = tf.matmul(w_ay, state_next) + b_ay
    return state_next, out


def rnn_cell_concat(x, n, a_0, y_0):
    """
    Unroll ``n`` steps of the concatenated-weight RNN (``cell2``).

    :param x: input tensor; assumes shape (input_dim, batch) — TODO confirm
    :param n: number of steps (cells) to unroll
    :param a_0: hidden state size
    :param y_0: output size
    :return: list of [state, y] pairs, one per step
    """
    out = []  # collects [state, y] for every unrolled step
    x_shape = x.shape.as_list()
    # Initial hidden state.
    state = tf.Variable(tf.truncated_normal((a_0, x_shape[1]), stddev=0.1))
    # Single weight over the concatenated [x; a]; width = input_dim + a_0.
    w_a = tf.Variable(tf.truncated_normal((a_0, x_shape[0] + a_0), stddev=0.1))
    w_ay = tf.Variable(tf.truncated_normal((y_0, a_0), stddev=0.1))
    b_aa = tf.Variable(tf.constant(0.1, shape=[a_0, x_shape[1]]))
    b_ay = tf.Variable(tf.constant(0.1, shape=[y_0, x_shape[1]]))
    for i in range(n):
        # BUG FIX: rebuild the concatenation each step so the *updated*
        # state feeds the next step. Previously x_a was computed once
        # before the loop, so every iteration reused the initial state
        # and the recurrence never actually recurred (unlike
        # rnn_cell_custom, which does feed state back).
        x_a = tf.concat([x, state], axis=0)
        state, y = cell2(x_a, w_a, w_ay, b_aa, b_ay)
        out.append([state, y])
    return out


def rnn_cell_1_def(x, a, batch_size=32, x_shape=100, num_units=128):
    """
    Single step through TF's built-in BasicRNNCell via its ``call`` method.

    NOTE(review): ``x`` and ``a`` are unused (a fresh placeholder and zero
    state are created instead); they are kept for interface compatibility.

    :param x: unused, kept for backward compatibility
    :param a: unused, kept for backward compatibility
    :param batch_size: batch size of the input placeholder (default 32,
        matching the original hard-coded value)
    :param x_shape: feature dimension of the input placeholder (default 100)
    :param num_units: RNN state size (default 128)
    :return: (output, next hidden state) of one cell step
    """
    inputs = tf.placeholder(tf.float32, shape=(batch_size, x_shape))
    # Renamed from ``cell`` to avoid shadowing the module-level cell().
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=num_units)
    h0 = rnn_cell.zero_state(batch_size, tf.float32)
    output, h1 = rnn_cell.call(inputs, h0)  # run one step explicitly
    return output, h1


def get_multi_lstm_cell(lstm_size, keep_prob):
    """
    Build ONE LSTM cell wrapped with output dropout.

    Despite the name, this returns a single cell; the caller is expected
    to invoke it repeatedly to stack layers (e.g. via MultiRNNCell).

    :param lstm_size: number of units in the LSTM cell
    :param keep_prob: keep probability applied to the cell's outputs
    :return: a DropoutWrapper around a BasicLSTMCell
    """
    return tf.nn.rnn_cell.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(lstm_size),
        output_keep_prob=keep_prob,
    )
